#include "halide_benchmark.h"
#include <assert.h>
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef SCHEDULE_ALL
#include "pipeline_nv12_linear_ro_async.h"
#include "pipeline_nv12_linear_ro_basic.h"
#include "pipeline_nv12_linear_ro_fold.h"
#include "pipeline_nv12_linear_ro_split.h"
#include "pipeline_nv12_linear_ro_split_async.h"

#include "pipeline_nv12_linear_rw_basic.h"
#include "pipeline_nv12_linear_rw_fold.h"
#endif
#include "pipeline_nv12_linear_rw_async.h"
#include "pipeline_nv12_linear_rw_split.h"
#include "pipeline_nv12_linear_rw_split_async.h"
#ifdef SCHEDULE_ALL
#include "pipeline_p010_linear_ro_async.h"
#include "pipeline_p010_linear_ro_basic.h"
#include "pipeline_p010_linear_ro_fold.h"
#include "pipeline_p010_linear_ro_split.h"
#include "pipeline_p010_linear_ro_split_async.h"

#include "pipeline_p010_linear_rw_basic.h"
#include "pipeline_p010_linear_rw_fold.h"
#endif
#include "HalideBuffer.h"
#include "HalideRuntimeHexagonDma.h"
#include "pipeline_p010_linear_rw_async.h"
#include "pipeline_p010_linear_rw_split.h"
#include "pipeline_p010_linear_rw_split_async.h"

// Column index of each schedule variant inside the schedule tables.
// SCHEDULE_MAX is both the number of variants and the "unrecognised schedule"
// sentinel used when parsing the command line.
enum {
    SCHEDULE_BASIC = 0,
    SCHEDULE_FOLD = 1,
    SCHEDULE_ASYNC = 2,
    SCHEDULE_SPLIT = 3,
    SCHEDULE_SPLIT_ASYNC = 4,
    SCHEDULE_MAX = 5
};

// Row index of the DMA direction inside the schedule tables:
//   DIRECTION_RW - DMA is used for both the input and the output buffers.
//   DIRECTION_RO - DMA is used for the input only; output stays in host memory.
// DIRECTION_MAX is the number of rows.
enum {
    DIRECTION_RW = 0,
    DIRECTION_RO = 1,
    DIRECTION_MAX = 2
};

// One selectable pipeline variant: a printable label and the generated
// pipeline entry point, which takes the input luma/chroma planes and the
// output luma/chroma planes. Both members are NULL for variants that were not
// compiled into this binary.
struct ScheduleList {
    const char *schedule_name;
    int (*schedule_call)(struct halide_buffer_t *in_y,
                         struct halide_buffer_t *in_uv,
                         struct halide_buffer_t *out_y,
                         struct halide_buffer_t *out_uv);
};

// Helper macros that build ScheduleList initializers.
//
// The helpers carry an _IMPL suffix rather than a leading underscore: names
// that begin with an underscore followed by a capital letter are reserved for
// the C++ implementation.
//
//   SCHEDULE_STR_IMPL(s)       - stringify the tokens of s (the entry label).
//   SCHEDULE_NAME_IMPL(d, r, s) - paste the generated pipeline function name,
//                                 e.g. pipeline_nv12_linear_rw_async.
//   SCHEDULE_PAIR_IMPL(d, r, s) - { "scheduled - pipeline(d, r, s)", function }.
//   SCHEDULE_DUMMY_PAIR_IMPL   - { NULL, NULL } placeholder for variants that
//                                 are not compiled in.
#define SCHEDULE_STR_IMPL(s) #s
#define SCHEDULE_NAME_IMPL(data, direction, schedule) pipeline_##data##_##direction##_##schedule
#define SCHEDULE_PAIR_IMPL(data, direction, schedule) \
    { SCHEDULE_STR_IMPL(scheduled - pipeline(data, direction, schedule)), SCHEDULE_NAME_IMPL(data, direction, schedule) }
#define SCHEDULE_DUMMY_PAIR_IMPL \
    { NULL, NULL }

// SCHEDULE_FUNCTION_RW always references a real read/write pipeline.
#define SCHEDULE_FUNCTION_RW(type, schedule) SCHEDULE_PAIR_IMPL(type##_linear, rw, schedule)

// Read-only pipelines (and the basic/fold read/write ones) are only declared
// when SCHEDULE_ALL is defined, so otherwise SCHEDULE_FUNCTION_RO collapses to
// the NULL placeholder.
#ifdef SCHEDULE_ALL
#define SCHEDULE_FUNCTION_RO(type, schedule) SCHEDULE_PAIR_IMPL(type##_linear, ro, schedule)
#else
#define SCHEDULE_FUNCTION_RO(type, schedule) SCHEDULE_DUMMY_PAIR_IMPL
#endif

// Pipeline entry points for NV12 (uint8_t) frames, indexed as
// [DIRECTION_*][SCHEDULE_*]. Variants that were not compiled in are
// { NULL, NULL } placeholders, which process_pipeline reports as
// "not built-in".
static ScheduleList schedule_listNV12[DIRECTION_MAX][SCHEDULE_MAX] = {
    // Row DIRECTION_RW
    {
#ifdef SCHEDULE_ALL
        SCHEDULE_FUNCTION_RW(nv12, basic),
        SCHEDULE_FUNCTION_RW(nv12, fold),
#else
        // basic/fold RW are not built in this configuration; without
        // SCHEDULE_ALL, SCHEDULE_FUNCTION_RO yields the NULL placeholder.
        SCHEDULE_FUNCTION_RO(nv12, basic),
        SCHEDULE_FUNCTION_RO(nv12, fold),
#endif
        SCHEDULE_FUNCTION_RW(nv12, async),
        SCHEDULE_FUNCTION_RW(nv12, split),
        SCHEDULE_FUNCTION_RW(nv12, split_async),
    },
    // Row DIRECTION_RO (all placeholders unless SCHEDULE_ALL is defined)
    {
        SCHEDULE_FUNCTION_RO(nv12, basic),
        SCHEDULE_FUNCTION_RO(nv12, fold),
        SCHEDULE_FUNCTION_RO(nv12, async),
        SCHEDULE_FUNCTION_RO(nv12, split),
        SCHEDULE_FUNCTION_RO(nv12, split_async),
    },
};

// Pipeline entry points for P010 (uint16_t) frames, indexed as
// [DIRECTION_*][SCHEDULE_*]. Variants that were not compiled in are
// { NULL, NULL } placeholders, which process_pipeline reports as
// "not built-in".
static ScheduleList schedule_listP010[DIRECTION_MAX][SCHEDULE_MAX] = {
    // Row DIRECTION_RW
    {
#ifdef SCHEDULE_ALL
        SCHEDULE_FUNCTION_RW(p010, basic),
        SCHEDULE_FUNCTION_RW(p010, fold),
#else
        // basic/fold RW are not built in this configuration; without
        // SCHEDULE_ALL, SCHEDULE_FUNCTION_RO yields the NULL placeholder.
        SCHEDULE_FUNCTION_RO(p010, basic),
        SCHEDULE_FUNCTION_RO(p010, fold),
#endif
        SCHEDULE_FUNCTION_RW(p010, async),
        SCHEDULE_FUNCTION_RW(p010, split),
        SCHEDULE_FUNCTION_RW(p010, split_async),
    },
    // Row DIRECTION_RO (all placeholders unless SCHEDULE_ALL is defined)
    {
        SCHEDULE_FUNCTION_RO(p010, basic),
        SCHEDULE_FUNCTION_RO(p010, fold),
        SCHEDULE_FUNCTION_RO(p010, async),
        SCHEDULE_FUNCTION_RO(p010, split),
        SCHEDULE_FUNCTION_RO(p010, split_async),
    },
};

template<typename T, size_t size_direction, size_t size_schedule>
inline int process_pipeline(T const &type, const int width, const int height,
                            const char *schedule, const char *dma_direction,
                            ScheduleList (&schedule_list)[size_direction][size_schedule]) {
    int ret = 0;

    // Fill the input buffer with random test data. This is just a plain old memory buffer
    const int buf_size = (width * height * 3) / 2;
    T *data_in = (T *)malloc(buf_size * sizeof(T));
    T *data_out = (T *)malloc(buf_size * sizeof(T));
    // Creating the Input Data so that we can catch if there are any Errors in DMA
    for (int i = 0; i < buf_size; i++) {
        data_in[i] = ((T)rand()) >> 1;
        data_out[i] = 0;
    }

    // Setup Halide input buffer with the test buffer
    Halide::Runtime::Buffer<T> input_validation(data_in, width, height, 2);
    Halide::Runtime::Buffer<T> input(nullptr, width, (3 * height) / 2);
    Halide::Runtime::Buffer<T> input_y = input.cropped(1, 0, height);            // Luma plane only
    Halide::Runtime::Buffer<T> input_uv = input.cropped(1, height, height / 2);  // Chroma plane only, with reduced height

    // describe the UV interleaving for 4:2:0 format
    input_uv.embed(2, 0);
    input_uv.raw_buffer()->dim[2].extent = 2;
    input_uv.raw_buffer()->dim[2].stride = 1;
    input_uv.raw_buffer()->dim[0].stride = 2;
    input_uv.raw_buffer()->dim[0].extent = width / 2;

    // Setup Halide output buffer
    Halide::Runtime::Buffer<T> output(width, (3 * height) / 2);
    Halide::Runtime::Buffer<T> output_y = output.cropped(1, 0, height);              // Luma plane only
    Halide::Runtime::Buffer<T> output_uv = output.cropped(1, height, (height / 2));  // Chroma plane only, with reduced height

    // describe the UV interleaving for 4:2:0 format
    output_uv.embed(2, 0);
    output_uv.raw_buffer()->dimensions = 3;
    output_uv.raw_buffer()->dim[2].extent = 2;
    output_uv.raw_buffer()->dim[2].stride = 1;
    output_uv.raw_buffer()->dim[0].stride = 2;
    output_uv.raw_buffer()->dim[0].extent = width / 2;

    // DMA_step 1: Assign buffer to DMA interface
    input_y.device_wrap_native(halide_hexagon_dma_device_interface(), reinterpret_cast<uint64_t>(data_in));
    input_uv.device_wrap_native(halide_hexagon_dma_device_interface(), reinterpret_cast<uint64_t>(data_in));
    input_y.set_device_dirty();
    input_uv.set_device_dirty();

    if (!strcmp(dma_direction, "rw")) {
        output_y.device_wrap_native(halide_hexagon_dma_device_interface(), reinterpret_cast<uint64_t>(data_out));
        output_uv.device_wrap_native(halide_hexagon_dma_device_interface(), reinterpret_cast<uint64_t>(data_out));
        output_y.set_device_dirty();
        output_uv.set_device_dirty();
    }

    // DMA_step 2: Allocate a DMA engine
    void *dma_engine = nullptr;
    void *dma_engine_write = nullptr;
    halide_hexagon_dma_allocate_engine(nullptr, &dma_engine);

    if ((!strcmp(schedule, "async") || !strcmp(schedule, "split_async")) && !strcmp(dma_direction, "rw")) {
        printf("A separate engine for DMA write\n");
        halide_hexagon_dma_allocate_engine(nullptr, &dma_engine_write);
    }

    halide_hexagon_image_fmt_t fmt_y = (sizeof(type) == 1) ? halide_hexagon_fmt_NV12_Y : halide_hexagon_fmt_P010_Y;
    halide_hexagon_image_fmt_t fmt_uv = (sizeof(type) == 1) ? halide_hexagon_fmt_NV12_UV : halide_hexagon_fmt_P010_UV;

    // DMA_step 3: Associate buffer to DMA engine, and prepare for copying to host (DMA read) and device (DMA write)
    halide_hexagon_dma_prepare_for_copy_to_host(nullptr, input_y, dma_engine, false, fmt_y);
    halide_hexagon_dma_prepare_for_copy_to_host(nullptr, input_uv, dma_engine, false, fmt_uv);
    if (!strcmp(dma_direction, "rw")) {
        if (!strcmp(schedule, "async") || !strcmp(schedule, "split_async")) {
            printf("Use separate engine for DMA output\n");
            halide_hexagon_dma_prepare_for_copy_to_device(nullptr, output_y, dma_engine_write, false, fmt_y);
            halide_hexagon_dma_prepare_for_copy_to_device(nullptr, output_uv, dma_engine_write, false, fmt_uv);
        } else {
            halide_hexagon_dma_prepare_for_copy_to_device(nullptr, output_y, dma_engine, false, fmt_y);
            halide_hexagon_dma_prepare_for_copy_to_device(nullptr, output_uv, dma_engine, false, fmt_uv);
        }
    }

    int my_direction = (!strcmp(dma_direction, "rw")) ? DIRECTION_RW : DIRECTION_RO;
    int my_schedule = SCHEDULE_MAX;
    if (!strcmp(schedule, "basic")) {
        my_schedule = SCHEDULE_BASIC;
    } else if (!strcmp(schedule, "fold")) {
        my_schedule = SCHEDULE_FOLD;
    } else if (!strcmp(schedule, "async")) {
        my_schedule = SCHEDULE_ASYNC;
    } else if (!strcmp(schedule, "split")) {
        my_schedule = SCHEDULE_SPLIT;
    } else if (!strcmp(schedule, "split_async")) {
        my_schedule = SCHEDULE_SPLIT_ASYNC;
    }
    if (my_schedule < SCHEDULE_MAX) {
        if (schedule_list[my_direction][my_schedule].schedule_name != NULL) {
            printf("%s\n", schedule_list[my_direction][my_schedule].schedule_name);
            ret = (*schedule_list[my_direction][my_schedule].schedule_call)(input_y, input_uv, output_y, output_uv);
        } else {
            printf("Schedule pipeline test not built-in (%s, %s)\n", dma_direction, schedule);
            ret = -2;
        }
    } else {
        printf("Incorrect input Correct schedule: basic, fold, async, split, split_async\n");
        ret = -1;
    }

    if (ret != 0) {
        printf("pipeline failed! %d\n", ret);
    } else {
        // verify result by comparing to expected values
        int error_count = 0;
        for (int y = 0; y < (3 * height) / 2; y++) {
            for (int x = 0; x < width; x++) {
                T correct = data_in[x + y * width] * 2;
                T result = (!strcmp(dma_direction, "rw")) ? data_out[x + y * width] : output(x, y);
                if (correct != result) {
                    printf("Mismatch at x=%d y=%d : %d != %d\n", x, y, correct, result);
                    if (++error_count > 20) abort();
                }
            }
        }
        printf("Success!\n");
    }

    // DMA_step 4: Buffer is processed, disassociate buffer from DMA engine
    //             Optional goto DMA_step 0 for processing more buffers
    halide_hexagon_dma_unprepare(nullptr, input_y);
    halide_hexagon_dma_unprepare(nullptr, input_uv);

    if (!strcmp(dma_direction, "rw")) {
        halide_hexagon_dma_unprepare(nullptr, output_y);
        halide_hexagon_dma_unprepare(nullptr, output_uv);
    }

    // DMA_step 5: Processing is completed and ready to exit, deallocate the DMA engine
    halide_hexagon_dma_deallocate_engine(nullptr, dma_engine);

    if ((!strcmp(schedule, "async") || !strcmp(schedule, "split_async")) && !strcmp(dma_direction, "rw")) {
        halide_hexagon_dma_deallocate_engine(nullptr, dma_engine_write);
    }

    free(data_in);
    free(data_out);

    return ret;
}

/**
 * Command-line entry point.
 *
 * Usage: <prog> width height schedule dma_direction yuv_type
 *   schedule      : basic | fold | async | split | split_async
 *   dma_direction : ro | rw
 *   yuv_type      : nv12 | p010 (any value other than "p010" runs NV12)
 *
 * Returns 0 after printing usage when arguments are missing, -1 for a
 * non-positive frame size, otherwise the status from process_pipeline.
 */
int main(int argc, char **argv) {
    int ret = 0;

    // Five positional arguments are read (argv[1]..argv[5]), so argc must be
    // at least 6; with argc == 5, argv[5] is the terminating NULL.
    if (argc < 6) {
        printf("Usage: %s width height schedule {basic, fold, async, split, split_async} dma_direction {ro, rw} yuv_type {nv12, p010}\n", argv[0]);
        return ret;
    }

    const int width = atoi(argv[1]);
    const int height = atoi(argv[2]);
    const char *schedule = argv[3];
    const char *dma_direction = argv[4];
    const char *yuv_type = argv[5];

    // atoi yields 0 for non-numeric input; reject sizes the buffers cannot use.
    if (width <= 0 || height <= 0) {
        printf("Invalid frame size %dx%d: width and height must be positive\n", width, height);
        return -1;
    }

    if (!strcmp(yuv_type, "p010")) {
        uint16_t type = 0;
        ret = process_pipeline(type, width, height, schedule, dma_direction, schedule_listP010);
    } else {
        uint8_t type = 0;
        ret = process_pipeline(type, width, height, schedule, dma_direction, schedule_listNV12);
    }

    return ret;
}
